


# ================================================
# 0) Install & load required packages
# ================================================
# install.packages(c("haven", "broom", "dplyr", "ggplot2"))  # if not already installed
library(haven)    # read_dta()
library(broom)    # tidy(), conf.int
library(dplyr)    # data‐wrangling
library(ggplot2)  # plotting

# ================================================
# 1) Load data and add a country indicator
# ================================================
# Read the merged Stata file, drop rows that carry neither a Swedish nor a
# Danish respondent id, and tag every remaining row with its country.
# A row with a non-missing id_se is labelled "Sweden"; every other surviving
# row necessarily has a non-missing id_dk and is labelled "Denmark".
df <- read_dta("merged_file_nomissings.dta") %>%
  filter(!(is.na(id_se) & is.na(id_dk))) %>%
  mutate(country = if_else(!is.na(id_se), "Sweden", "Denmark"))

# ================================================
# 2) Define outcomes and subgroup variables + labels
# ================================================
# Display labels for the outcome variables, keyed by variable name.
outcome_labels <- c(
  marry_n   = "Social Distancing",
  traits_n  = "Trait Stereotyping",
  dislike_n = "Party Dislike"
)

# Outcome variable names, in the order they are modelled and plotted.
outcomes <- names(outcome_labels)

# ================================================
# Create binary subgroup variables (0 = at or below mean, 1 = above mean)
# ================================================

# Dichotomise each identity measure at its mean: 1 when the value is strictly
# above the mean, 0 otherwise (a missing value stays missing). Each new column
# is named after its source column with a "_d" suffix.
# The mean is taken over all rows of df, i.e. both countries pooled.
# NOTE(review): the models further down are fit per country -- confirm that a
# pooled (rather than country-specific) cut-point is intended.
df <- df %>%
  mutate(across(
    all_of(c("attach_n", "ident_n", "proud_n", "pride_inst_n", "pride_cult_n")),
    function(v) ifelse(v > mean(v, na.rm = TRUE), 1, 0),
    .names = "{.col}_d"
  ))

# Display labels for the subgroup dummies, keyed by dummy-variable name.
sub_labels <- c(
  attach_n_d     = "Attachment",
  ident_n_d      = "Identification",
  proud_n_d      = "General Pride",
  pride_inst_n_d = "Institutional Pride",
  pride_cult_n_d = "Cultural Pride"
)

# Subgroup dummy names, in the order they are modelled and plotted.
subgroups <- names(sub_labels)


# ================================================
# 3) Flag treatment: fit treatment x subgroup interaction models and extract the interaction coefficient
# ================================================
# Flag treatment: for every country x outcome x subgroup combination, regress
# the outcome on treat_flag, the binary subgroup dummy and their interaction
# (outcome ~ treat_flag * subgroup), and keep only the interaction row:
# its estimate, confidence bounds (as returned by tidy(conf.int = TRUE)) and
# p-value. The interaction coefficient is the difference in the treatment
# coefficient between subgroup == 1 and subgroup == 0.
results <- list()
for (ctry in c("Sweden", "Denmark")) {
  # The country subset does not depend on outcome or subgroup: build it once.
  dat <- filter(df, country == ctry)
  for (y in outcomes) {
    for (sg in subgroups) {
      mod <- lm(as.formula(paste0(y, " ~ treat_flag * ", sg)), data = dat)
      trm <- paste0("treat_flag:", sg)
      tt  <- tidy(mod, conf.int = TRUE) %>%
        filter(term == trm) %>%
        transmute(
          country   = ctry,
          outcome   = y,
          subgroup  = sg,
          estimate  = estimate,
          conf.low  = conf.low,
          conf.high = conf.high,
          p.value   = p.value
        )
      # If the term name does not match (e.g. the treatment or the dummy is
      # stored as a factor, which changes the coefficient name), the row would
      # otherwise disappear from the results without any notice.
      if (nrow(tt) == 0L) {
        warning(
          "No '", trm, "' coefficient for ", ctry, " / ", y,
          "; this combination is missing from the results.",
          call. = FALSE
        )
      }
      results[[paste(ctry, y, sg, sep = "_")]] <- tt
    }
  }
}

# Stack the per-model rows and convert the grouping columns to factors so that
# plot order follows the order of sub_labels / outcomes rather than the
# alphabet. unname() keeps the label column free of stray element names.
diff_df <- bind_rows(results) %>%
  mutate(
    subgroup_label = factor(
      unname(sub_labels[subgroup]),
      levels = unname(sub_labels)
    ),
    country        = factor(country, levels = c("Sweden", "Denmark")),
    outcome        = factor(outcome, levels = outcomes),
    significant    = p.value < 0.05
  )

# ================================================
# 4) Combined plot: one column per country, one row per outcome
# ================================================
# Ensure outcome_label exists
# Attach display labels for the facet strips. Factor levels follow the order
# of outcome_labels, so the facet rows appear in that order.
diff_df <- diff_df %>%
  mutate(
    outcome_label = factor(
      as.character(outcome),
      levels = names(outcome_labels),
      labels = unname(outcome_labels)
    )
  )

# Coefficient plot of the interaction estimates (no colouring by
# significance): one row of panels per outcome, one column per country.
# The plotted value is the treat_flag x subgroup interaction coefficient, so
# the axis is labelled as such rather than as a marginal mean.
# geom_linerange() with xmin/xmax draws the same cap-less horizontal interval
# as geom_errorbarh(height = 0).
plot_flag <- ggplot(diff_df, aes(x = estimate, y = subgroup_label)) +
  geom_vline(xintercept = 0, linetype = "dashed") +
  geom_linerange(aes(xmin = conf.low, xmax = conf.high)) +
  geom_point(size = 3) +
  facet_grid(
    rows   = vars(outcome_label),
    cols   = vars(country),
    scales = "free_y"
  ) +
  labs(
    x = "Interaction estimate: Flag treatment x subgroup (95% CI)",
    y = NULL
  ) +
  theme_minimal() +
  theme(
    strip.text       = element_text(face = "bold"),
    panel.grid.minor = element_blank(),
    axis.text.y      = element_text(size = 11)
  )

# Print explicitly: a bare ggplot object at top level is not drawn when the
# script is run through source().
print(plot_flag)




###################################################################

# CAKE treatment: same analysis as above, using treat_cake instead of treat_flag

#################################################################



# ================================================
# 3) Cake treatment: fit treatment x subgroup interaction models and extract the interaction coefficient
# ================================================
# Cake treatment: for every country x outcome x subgroup combination, regress
# the outcome on treat_cake, the binary subgroup dummy and their interaction
# (outcome ~ treat_cake * subgroup), and keep only the interaction row:
# its estimate, confidence bounds (as returned by tidy(conf.int = TRUE)) and
# p-value. The interaction coefficient is the difference in the treatment
# coefficient between subgroup == 1 and subgroup == 0.
# Note: this overwrites the `results` and `diff_df` objects of the flag
# analysis.
results <- list()
for (ctry in c("Sweden", "Denmark")) {
  # The country subset does not depend on outcome or subgroup: build it once.
  dat <- filter(df, country == ctry)
  for (y in outcomes) {
    for (sg in subgroups) {
      mod <- lm(as.formula(paste0(y, " ~ treat_cake * ", sg)), data = dat)
      trm <- paste0("treat_cake:", sg)
      tt  <- tidy(mod, conf.int = TRUE) %>%
        filter(term == trm) %>%
        transmute(
          country   = ctry,
          outcome   = y,
          subgroup  = sg,
          estimate  = estimate,
          conf.low  = conf.low,
          conf.high = conf.high,
          p.value   = p.value
        )
      # If the term name does not match (e.g. the treatment or the dummy is
      # stored as a factor, which changes the coefficient name), the row would
      # otherwise disappear from the results without any notice.
      if (nrow(tt) == 0L) {
        warning(
          "No '", trm, "' coefficient for ", ctry, " / ", y,
          "; this combination is missing from the results.",
          call. = FALSE
        )
      }
      results[[paste(ctry, y, sg, sep = "_")]] <- tt
    }
  }
}

# Stack the per-model rows and convert the grouping columns to factors so that
# plot order follows the order of sub_labels / outcomes rather than the
# alphabet. unname() keeps the label column free of stray element names.
diff_df <- bind_rows(results) %>%
  mutate(
    subgroup_label = factor(
      unname(sub_labels[subgroup]),
      levels = unname(sub_labels)
    ),
    country        = factor(country, levels = c("Sweden", "Denmark")),
    outcome        = factor(outcome, levels = outcomes),
    significant    = p.value < 0.05
  )

# ================================================
# 4) Combined plot: one column per country, one row per outcome
# ================================================
# Ensure outcome_label exists
# Attach display labels for the facet strips. Factor levels follow the order
# of outcome_labels, so the facet rows appear in that order.
diff_df <- diff_df %>%
  mutate(
    outcome_label = factor(
      as.character(outcome),
      levels = names(outcome_labels),
      labels = unname(outcome_labels)
    )
  )

# Coefficient plot of the interaction estimates (no colouring by
# significance): one row of panels per outcome, one column per country.
# The plotted value is the treat_cake x subgroup interaction coefficient, so
# the axis is labelled as such rather than as a marginal mean.
# geom_linerange() with xmin/xmax draws the same cap-less horizontal interval
# as geom_errorbarh(height = 0).
plot_cake <- ggplot(diff_df, aes(x = estimate, y = subgroup_label)) +
  geom_vline(xintercept = 0, linetype = "dashed") +
  geom_linerange(aes(xmin = conf.low, xmax = conf.high)) +
  geom_point(size = 3) +
  facet_grid(
    rows   = vars(outcome_label),
    cols   = vars(country),
    scales = "free_y"
  ) +
  labs(
    x = "Interaction estimate: Cake treatment x subgroup (95% CI)",
    y = NULL
  ) +
  theme_minimal() +
  theme(
    strip.text       = element_text(face = "bold"),
    panel.grid.minor = element_blank(),
    axis.text.y      = element_text(size = 11)
  )

# Print explicitly: a bare ggplot object at top level is not drawn when the
# script is run through source().
print(plot_cake)
